[Aiter][ROCm] RMSNormGated+GroupedQuantFP8 fusion#40710
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces fusion support for RMSNormGated followed by FP8 group quantization on ROCm platforms using the aiter library. Key changes include the registration of a new fused custom operator, the implementation of a MatcherRMSNormGated class, and updates to the RocmAiterRMSNormQuantFusionPass to discover and fuse these patterns. Feedback focuses on critical safety issues regarding the global monkey-patching of the pattern matcher's type handling, which could lead to incorrect matches for other operators. Additionally, improvements were suggested to ensure the gated fusion pattern correctly supports both aiter and decomposed quantization variants and strictly validates the supported group size of 128 to prevent numerical errors.
```python
_orig_fx_to_pat = pm.fx_to_pattern


def _relaxed_fx_to_pattern(*a, **kw):
    kw["ignore_types"] = (int, torch.SymInt)
    return _orig_fx_to_pat(*a, **kw)


pm.fx_to_pattern = _relaxed_fx_to_pattern
try:
    self.matched_count = self.patterns.apply(graph)
finally:
    pm.fx_to_pattern = _orig_fx_to_pat
```
Monkey-patching pm.fx_to_pattern to ignore all int and torch.SymInt types is extremely dangerous. This change affects all patterns registered in self.patterns, including those that rely on specific integer arguments for correctness (e.g., group_size=128 in AiterRMSFp8GroupQuantPattern). If a graph contains a quantization op with a different group size (e.g., 64), the matcher will incorrectly identify it as a match, leading to a replacement with a fused op that uses the wrong group size. This will cause silent numerical errors. A more targeted approach to handle SymInt in reshapes should be used instead of a global type ignore.
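For illustration, a more scoped variant of the override might look like the sketch below (not the PR's code): the restore-on-exit behavior is kept, but the relaxed `ignore_types` is applied only while the one pattern set that needs it runs, so other registered patterns still match on their exact integer arguments. Here `pm_module` stands in for `torch._inductor.pattern_matcher`, and the function name is hypothetical.

```python
import contextlib


@contextlib.contextmanager
def scoped_fx_to_pattern_override(pm_module, ignore_types=(int,)):
    """Temporarily wrap pm_module.fx_to_pattern so that only pattern
    application performed inside this context sees the relaxed
    ignore_types; the original function is restored on exit even if
    matching raises."""
    orig = pm_module.fx_to_pattern

    def wrapper(*args, **kwargs):
        # Do not clobber an ignore_types the caller passed explicitly.
        kwargs.setdefault("ignore_types", ignore_types)
        return orig(*args, **kwargs)

    pm_module.fx_to_pattern = wrapper
    try:
        yield
    finally:
        pm_module.fx_to_pattern = orig
```

Usage would be `with scoped_fx_to_pattern_override(pm): gated_patterns.apply(graph)`, leaving every other `self.patterns.apply(...)` call untouched.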
I haven't found a better approach, given the shortcomings of PyTorch's pattern matcher. This is becoming a common problem, especially when multiple reshapes exist.
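For context, the chained-reshape problem can be mitigated with a small FX pre-pass that collapses consecutive reshapes before matching. This is a sketch of the idea, not vLLM's exact `_fold_consecutive_reshapes`:

```python
import torch
from torch import fx

# Both the Python-level and ATen-level reshape targets can appear in traces.
_RESHAPE_TARGETS = (torch.reshape, torch.ops.aten.reshape.default)


def fold_consecutive_reshapes(graph: fx.Graph) -> None:
    """Collapse reshape(reshape(x, s1), s2) into reshape(x, s2) so that a
    traced pattern containing a single reshape can still match graphs
    where make_fx recorded a chain of them."""
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target not in _RESHAPE_TARGETS:
            continue
        src = node.args[0]
        if (isinstance(src, fx.Node) and src.op == "call_function"
                and src.target in _RESHAPE_TARGETS and len(src.users) == 1):
            # Re-point the outer reshape at the inner reshape's input,
            # then delete the now-unused inner reshape node.
            node.args = (src.args[0], node.args[1])
            graph.erase_node(src)
```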
Force-pushed 5c39363 to 2c82404
Force-pushed 2c82404 to d4f1b17
Force-pushed 31da8cb to 7b6683e
Some cleanup has been done; this needs higher-level feedback and a ready label to allow more complete testing.

@gshtras Can you add the
Force-pushed 5753895 to 3307453
@tjtanaa seems to be the relevant CODEOWNER for this PR.
Force-pushed f7fa464 to a786d20
Hi @tpopp, the pre-commit checks have failed. Please run:

```
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
Force-pushed 1d1f946 to 232157b
Force-pushed 232157b to 6881651
…y check

Register fused_rms_gated_fp8_group_quant custom op that wraps the aiter Triton kernel for fused gated RMSNorm + FP8 group quantization. Also add are_gdn_triton_kernels_available() to check whether the required aiter Triton kernels (conv1d single-token, gated delta net) are importable, allowing graceful fallback on older aiter versions.

Made-with: Cursor
Signed-off-by: Tres Popp <tres.popp@amd.com>
Implement pattern matching and replacement for decomposed RMSNormGated followed by group FP8 quantization, fusing them into a single aiter Triton kernel (fused_rms_gated_fp8_group_quant).

Key changes:
- Add AiterRMSNormGatedFp8GroupQuantPattern in rocm_aiter_fusion.py that matches the decomposed norm+reshape+quant graph and replaces it with the fused op
- Extend MatcherQuantFP8 and MatcherRMSNormGated in matcher_utils.py to support the gated norm pattern tracing
- Add forward_static to RMSNormGated for code sharing with the matcher and have forward_native delegate to it
- Simplify input_quant_fp8.py by extracting shared logic into forward_static
- Dynamically infer num_heads/head_dim from GatedDeltaNetAttention layers via static_forward_context
- Register per-token dynamic quant patterns for both aiter and non-aiter quant ops to handle +/- quant_fp8 configurations
- Gate the gated pattern on are_gdn_triton_kernels_available()
- Add unit tests for the fusion pattern (positive and negative cases)

Made-with: Cursor
Signed-off-by: Tres Popp <tres.popp@amd.com>
- Remove unused MatcherFusedAddRMSNorm and its dead imports (RMSNorm, RMS_ADD_OP)
- Move fold_consecutive_reshapes to vllm_inductor_pass.py next to the related _fx_view_to_reshape helper
- Add docstrings to new _aiter_ops methods (fused_rms_gated_fp8_group_quant impl and getter)
- Check fused_rms_gated_fp8_group_quant importability in are_gdn_triton_kernels_available
- Restore docstring on RMSNormGated.forward_native

Signed-off-by: Tres Popp <trespopp@gmail.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Iterate over use_triton for group quant patterns so both the CK and Triton backends are matched. Use a set to deduplicate when quant_fp8 is disabled (forward_native is identical for both use_triton values). Add a head_dim == 128 guard to AiterRMSNormGatedFp8GroupQuantPattern since the fused kernel hardcodes group_size=head_dim. Rename _fx_view_to_reshape to fx_view_to_reshape as it is not private.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Force-pushed 6881651 to fe5b13f
Force-pushed fe5b13f to 141e8c5
The Triton vs. CK group quant op selection was added speculatively, but the approved PR vllm-project#41825 uses only the aiter (CK) group quant op. Align the pattern matching with that decision.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Force-pushed 141e8c5 to 380c7ce
Revert the fx_view_to_reshape rename since the function already exists upstream with the underscore prefix. Only apply the ignore_types monkey-patch for the pattern matcher when gated norm patterns are actually registered, avoiding interference with existing per-token and per-tensor fusion patterns.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
The gated RMSNorm + group FP8 quant pattern matches when the quant op traces through native code (-quant_fp8) rather than the custom op. Remove ops_in_model_before since the pre-fusion quant op depends on the custom_ops config.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Force-pushed 4ecce31 to 20e4933
These parameters were unintentionally removed during earlier cleanup. They are needed by the existing non-gated pattern registration logic.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@tjtanaa Do you mind taking a look?

I have a plan for a vllm_ir-related form of rms_norm_gated but would like to do that as a follow-up, so it's precisely targeted, and so I can separately clarify some details about how the IR and pattern matching behave when custom ops are enabled.

I also haven't seen a better way to handle the shape-related data gathering. As far as I can tell, the reliance on constructing the same ops forces us to construct patterns with the exact derived constants, but pointers are welcome if I'm wrong.
Pass match_aiter_quant through to super().__init__() instead of creating a separate MatcherQuantFP8. The base class already creates the matcher with the correct quant key.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
…-fusion

Co-authored-by: Cursor <cursoragent@cursor.com>

# Conflicts:
#	vllm/compilation/passes/fusion/matcher_utils.py
#	vllm/compilation/passes/fusion/rocm_aiter_fusion.py

Signed-off-by: Tres Popp <tres.popp@amd.com>
This PR adds a compilation fusion pass (AiterRMSNormGatedFp8GroupQuantPattern) that fuses the decomposed RMSNormGated + reshape + group FP8 quantization sequence into a single AITER Triton kernel call (fused_rms_gated_fp8_group_quant). This pattern appears in GatedDeltaNetAttention layers (e.g., Qwen3-Next) where each attention head's output goes through gated RMS normalization, is reshaped back to the full hidden dimension, and then group-quantized to FP8 before the output projection linear layer.
Results:
A pair of kernels totaling ~9 µs can be combined into a single ~4.5 µs kernel. For Qwen3-Next this yields a 1-3% end-to-end improvement, depending on how small the workload is (concurrency 1 vs. 128).
Motivation
In models using GatedDeltaNetAttention (such as Qwen3-Next-80B-A3B-Instruct-FP8), the output path of each attention block performs:
1. gated RMS normalization of each attention head's output,
2. a reshape back to the full hidden dimension, and
3. group FP8 quantization before the output projection linear layer.
These three operations decompose into many elementwise and reduction kernels when torch.compile lowers them. By matching this pattern in the FX graph and replacing it with a single fused Triton kernel from AITER, we eliminate multiple GPU kernel launches and intermediate memory traffic.
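As a reference, the unfused sequence looks roughly like the sketch below in pure PyTorch. The SiLU gating, the gate-before-norm ordering, and the FP8 e4m3 range of 448 are stated assumptions for illustration, not the exact RMSNormGated semantics in vLLM:

```python
import torch


def gated_norm_reshape_quant(x, gate, weight, eps=1e-6, group_size=128):
    """Reference (unfused) sequence: gated RMS norm -> reshape -> group quant.

    Shapes: x and gate are (num_tokens, num_heads, head_dim); weight is
    (head_dim,). Gate ordering and SiLU gating are illustrative assumptions.
    """
    FP8_MAX = 448.0  # max finite value of float8_e4m3fn
    x = x * torch.nn.functional.silu(gate)            # gated activation
    var = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(var + eps) * weight           # RMS norm per head
    n, h, d = x.shape
    x = x.reshape(n, h * d)                           # back to hidden dim
    g = x.reshape(n, -1, group_size)
    scale = g.abs().amax(-1, keepdim=True) / FP8_MAX  # per-group dynamic scale
    q = (g / scale.clamp(min=1e-12)).clamp(-FP8_MAX, FP8_MAX)
    return q.reshape(n, h * d), scale.reshape(n, -1)
```

Each of the three stages launches its own elementwise/reduction kernels when compiled separately, which is exactly the traffic the fused AITER kernel removes.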
Changes
• Register rocm_aiter_fused_rms_gated_fp8_group_quant custom op wrapping aiter.ops.triton.quant.fused_rms_gated_fp8_group_quant
• Add rocm_aiter_ops.are_gdn_triton_kernels_available() — checks whether the required AITER Triton kernels (causal_conv1d_update_single_token, gated_delta_net) are importable, allowing graceful fallback on older AITER builds that lack the GDN kernels
• rocm_aiter_fusion.py: Add AiterRMSNormGatedFp8GroupQuantPattern that matches the decomposed norm→reshape→quant graph and replaces it with the fused op. Add _fold_consecutive_reshapes pre-processing pass (needed because make_fx faithfully records chained reshapes that must be folded for the pattern to match). Dynamically infer num_heads/head_dim from GatedDeltaNetAttention layers via static_forward_context. Gate the pattern on are_gdn_triton_kernels_available()
• matcher_utils.py: Add MatcherRMSNormGated pattern tracer that traces RMSNormGated.forward_static for use in pm.register_replacement. Extend MatcherQuantFP8 to support Triton-based quant op matching
• layernorm.py: Extract RMSNormGated.forward_static as a @staticmethod so both forward_native and the matcher can share the same pure-PyTorch implementation. forward_native delegates to it
• test_fusion.py: Add unit tests (TestGatedModel) for the fusion pattern covering positive match cases (aiter quant, non-aiter quant, per-token dynamic) and negative cases (wrong group shape, per-tensor quant)
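The availability check described above can be sketched as a cached import probe. The probed symbol follows the names used in this PR; the precise set of checks inside rocm_aiter_ops may differ:

```python
import functools


@functools.cache
def are_gdn_triton_kernels_available() -> bool:
    """Return True if the required aiter Triton kernels are importable.

    Caching means the import is attempted at most once per process; on
    older aiter builds that lack the GDN kernels this returns False and
    the fusion pass becomes a no-op."""
    try:
        # Illustrative probe; the real check also covers the GDN kernels
        # (causal_conv1d_update_single_token, gated_delta_net).
        from aiter.ops.triton.quant import (  # noqa: F401
            fused_rms_gated_fp8_group_quant,
        )
        return True
    except ImportError:
        return False
```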
AITER Dependency
The fused Triton kernel (fused_rms_gated_fp8_group_quant) is provided by ROCm/aiter#2423 (https://github.com/ROCm/aiter/pull/2423) ("[Triton] optimized decode kernels for Qwen3-Next model"). The fusion pass is gated behind rocm_aiter_ops.are_gdn_triton_kernels_available(), so it is a no-op on AITER versions that do not include this PR.
Benchmark Results
Setup:
• Model: Qwen/Qwen3-Next-80B-A3B-Instruct-FP8, TP=1
• GPU: AMD MI355x (gfx950), single GPU
• Base image: vllm/vllm-openai-rocm:nightly (vLLM v0.19.2rc1) with AITER rebuilt from aiter:main + PR #2423
• Attention backend: ROCM_AITER_FA
• Compilation: cudagraph_mode=FULL_AND_PIECEWISE, custom_ops=["-rms_norm", "-silu_and_mul", "+quant_fp8"], pass_config={"fuse_norm_quant": true}
• Benchmark command: vllm bench serve --dataset_name random --random_input_len 1024 --random_output_len 1024 --max_concurrency 4 --num_prompts 32 --num_warmups 4 --seed 1 --temperature 0 --ignore_eos
Pattern matching verification:
• With fusion: RocmAiterRMSNormQuantFusionPass replaced 5 patterns (1+2+2 across repeated-layer subgraphs — the 4 additional matches are from AiterRMSNormGatedFp8GroupQuantPattern)
• Without fusion (pattern commented out): replaced 1 pattern (only the existing non-gated AiterRMSNormDynamicQuantPattern)
Throughput (ISL=1024, OSL=1024, concurrency=4):
┌─────────────────────────────────┬─────────────┬──────────┬───────┐
│ Metric │ With Fusion │ Baseline │ Delta │
├─────────────────────────────────┼─────────────┼──────────┼───────┤
│ Output token throughput (tok/s) │ 467.05 │ 456.52 │ +2.3% │
│ Total token throughput (tok/s) │ 934.11 │ 913.04 │ +2.3% │
│ Mean TPOT (ms) │ 8.44 │ 8.66 │ −2.5% │
│ P99 TPOT (ms) │ 8.67 │ 8.98 │ −3.5% │
│ Mean E2EL (ms) │ 8,769 │ 8,971 │ −2.3% │
└─────────────────────────────────┴─────────────┴──────────┴───────┘
Accuracy (lm_eval, gsm8k, 5-shot):
┌──────────────────┬────────────────┬────────────────┬─────────────────────────────┐
│ Filter │ With Fusion │ Baseline │ Delta │
├──────────────────┼────────────────┼────────────────┼─────────────────────────────┤
│ flexible-extract │ 0.8605 ±0.0095 │ 0.8506 ±0.0098 │ +0.0099 (within error bars) │
│ strict-match │ 0.8089 ±0.0108 │ 0.8097 ±0.0108 │ −0.0008 (within error bars) │
└──────────────────┴────────────────┴────────────────┴─────────────────────────────┘
Accuracy is statistically identical — the fusion is numerically safe.
Test plan
• [x] Unit tests: pytest tests/compile/passes/test_fusion.py -k "gated" — positive and negative pattern match cases
• [x] lm_eval --tasks gsm8k --num_fewshot 5 — accuracy unchanged vs. baseline
• [x] vllm bench serve — throughput improved ~2.3%, TPOT improved ~2.5%
• [x] Verified graceful no-op when AITER lacks GDN kernels (are_gdn_triton_kernels_available() == False)